import torch
import torch.nn as nn
import torch.nn.functional as F

from models.layers import *
import torch.distributed as dist
from pytorch3d.transforms import matrix_to_quaternion, quaternion_to_matrix
from mamba_ssm.modules.mamba_simple import Mamba 


class HierFeatureExtraction(nn.Module):
    """Three-level keypoint detection and description pyramid.

    Each level runs a ``KeypointDetector`` followed by a ``DescExtractor``.
    The keypoints and attentive features produced by one level are the input
    of the next level. When ``args.use_weights`` is set, the sigmas of a level
    are converted to mean-normalised inverse weights and passed to the next
    detector as an extra positional argument.
    """

    def __init__(self, args):
        """Build the detectors and descriptor extractors.

        Args:
            args: configuration object; ``use_fps``, ``use_weights`` and
                ``freeze_detector`` are read here.
        """
        super().__init__()

        self.use_fps = args.use_fps
        self.use_weights = args.use_weights

        fps = self.use_fps
        self.detector_1 = KeypointDetector(
            nsample=1024, k=64, in_channels=0, out_channels=[32, 32, 64], fps=fps)
        self.detector_2 = KeypointDetector(
            nsample=512, k=32, in_channels=64, out_channels=[64, 64, 128], fps=fps)
        self.detector_3 = KeypointDetector(
            nsample=256, k=16, in_channels=128, out_channels=[128, 128, 256], fps=fps)

        # Only the detectors are registered at this point, so the descriptor
        # extractors created below are not affected by the freeze.
        if args.freeze_detector:
            for param in self.parameters():
                param.requires_grad = False

        self.desc_extractor_1 = DescExtractor(
            in_channels=0, out_channels=[32, 32, 64], C_detector=64, desc_dim=64)
        self.desc_extractor_2 = DescExtractor(
            in_channels=64, out_channels=[64, 64, 128], C_detector=128, desc_dim=128)
        self.desc_extractor_3 = DescExtractor(
            in_channels=128, out_channels=[128, 128, 256], C_detector=256, desc_dim=256)

    @staticmethod
    def _inverse_sigma_weights(sigmas):
        """Return ``1 / (sigmas + 1e-5)`` divided by its mean over dim 1."""
        inverse = 1.0 / (sigmas + 1e-5)
        return inverse / torch.mean(inverse, dim=1, keepdim=True)

    def forward(self, points):
        """Run the three levels on ``points``.

        Args:
            points: input handed to the first detector (layout is defined by
                ``KeypointDetector``, which is not visible here).

        Returns:
            dict with keys ``xyz_i``, ``sigmas_i`` and ``desc_i`` for
            ``i`` in 1..3.
        """
        levels = (
            (self.detector_1, self.desc_extractor_1),
            (self.detector_2, self.desc_extractor_2),
            (self.detector_3, self.desc_extractor_3),
        )

        xyz_levels, sigma_levels, desc_levels = [], [], []
        xyz, attentive_feature, previous_sigmas = points, None, None
        for detect, describe in levels:
            # The first level never receives weights; later levels receive
            # them only when weighting is enabled. Without weighting the
            # detector is called with exactly two positional arguments.
            if self.use_weights and previous_sigmas is not None:
                extra_args = (self._inverse_sigma_weights(previous_sigmas),)
            else:
                extra_args = ()

            xyz, sigmas, attentive_feature, grouped_features, attentive_feature_map = detect(
                xyz, attentive_feature, *extra_args)
            descriptor = describe(grouped_features, attentive_feature_map)

            xyz_levels.append(xyz)
            sigma_levels.append(sigmas)
            desc_levels.append(descriptor)
            previous_sigmas = sigmas

        return {
            'xyz_1': xyz_levels[0],
            'xyz_2': xyz_levels[1],
            'xyz_3': xyz_levels[2],
            'sigmas_1': sigma_levels[0],
            'sigmas_2': sigma_levels[1],
            'sigmas_3': sigma_levels[2],
            'desc_1': desc_levels[0],
            'desc_2': desc_levels[1],
            'desc_3': desc_levels[2],
        }

class HRegNet(nn.Module):

    def __init__(self, args):
        """Assemble the registration network.

        Args:
            args: configuration object. ``freeze_feats``, ``K_C``, ``K_N``
                and ``batch_size`` are read here; ``mem_length`` is read by
                ``init_history``; the whole object is also forwarded to
                ``HierFeatureExtraction``.
        """
        super().__init__()
        self.feature_extraction = HierFeatureExtraction(args)

        # Freeze pretrained features when training. Only the feature
        # extractor is registered so far, so nothing created below is frozen.
        if args.freeze_feats:
            for param in self.parameters():
                param.requires_grad = False
        self.args = args

        corres_k = args.K_C
        self.coarse_corres = CoarseReg(
            k=corres_k, nbr_k=args.K_N, in_channels=256,
            use_sim=True, use_neighbor=True, use_ge=True)
        self.fine_corres_2 = FineReg(k=corres_k, in_channels=128)
        self.fine_corres_1 = FineReg(k=corres_k, in_channels=64)

        self.svd_head = WeightedSVDHead()

        embed_dim = 64
        self.mlp_q = GatedMLP(4, embed_dim, embed_dim)
        self.mlp_t = GatedMLP(3, embed_dim, embed_dim)
        self.mlp_feature_coarse = GatedMLP(512, 128, embed_dim)

        def temporal_block():
            # LayerNorm followed by a Mamba layer over the history sequence.
            return nn.Sequential(
                nn.LayerNorm(embed_dim),
                Mamba(d_model=embed_dim, d_state=16, d_conv=4, expand=2),
            )

        self.mamba_q = temporal_block()
        self.mamba_feature_coarse = temporal_block()
        self.mamba_t = temporal_block()

        # Decoder input = current feature embedding + history embedding
        # + last pose in the history + pose estimated by the SVD head.
        quat_in = embed_dim + embed_dim + 4 + 4
        trans_in = embed_dim + embed_dim + 3 + 3
        self.q_dynamic_coarse = Dynamic_Trajectory_Decoder(quat_in, 4)
        self.t_dynamic_coarse = Dynamic_Trajectory_Decoder(trans_in, 3)
        self.q_dynamic_fine = Dynamic_Trajectory_Decoder(quat_in, 4)
        self.t_dynamic_fine = Dynamic_Trajectory_Decoder(trans_in, 3)

        self.init_history(args.batch_size)


    def init_history(self, batch_size):
        feature_dim = 64
        max_frames = self.args.mem_length   
        self.continue_frames = 0  

        self.his_coarse_attentive_feats = torch.zeros((batch_size, max_frames, feature_dim)).cuda()

        zero_quaternion = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32).cuda() 
        small_translation = torch.full((batch_size, max_frames, 3), 1e-15).cuda() 

        self.q_history = zero_quaternion.repeat(batch_size, max_frames, 1)
        self.t_history = small_translation

        self.history_mask = torch.zeros((batch_size, max_frames)).cuda()
        self.continue_frames


    def update_history(self, new_state, history_state):

        unpdate_state = torch.cat([history_state[:, 1:], new_state.clone().detach()], dim=1)
        
        return unpdate_state

    def update_history_mask(self, history_mask):
        update_mask = torch.cat([history_mask[:, 1:], torch.ones((history_mask.shape[0], 1)).cuda()], dim=1)
        return update_mask
    
    def forward(self, src_points, dst_points):
        """Estimate the transform from ``src_points`` to ``dst_points``.

        Pipeline: hierarchical feature extraction on both clouds, coarse
        registration on level 3, history-conditioned ("dynamic") refinement
        of the coarse pose, fine registration on levels 2 and 1, and a final
        dynamic refinement of the level-1 pose.

        Side effects: the pose and feature histories and ``history_mask`` are
        shifted by one step and ``continue_frames`` is incremented, so
        successive calls are treated as consecutive frames. Nothing here
        resets that state; call ``init_history`` to start a new sequence.

        Args:
            src_points: source point cloud passed to the feature extractor.
            dst_points: destination point cloud passed to the feature
                extractor.

        Returns:
            dict with
            ``'rotation'``: ``[R3, R2, R1, R3_dynamic, R1_dynamic]``,
            ``'translation'``: ``[t3, t2, t1, t3_dynamic, t1_dynamic]``,
            ``'src_feats'`` / ``'dst_feats'``: the feature-extractor outputs.
            ``R2``/``t2`` and ``R1``/``t1`` are composed on top of the
            *dynamic* coarse pose, not the raw ``R3``/``t3``.
        """

        def to_homogeneous(rot, trans):
            # Build the 4x4 transform on the rotation's own device.
            # dist.get_rank() is the global rank, not a CUDA device index,
            # and raises when no process group has been initialised.
            mat = torch.zeros(rot.shape[0], 4, 4, device=rot.device)
            mat[:, :3, :3] = rot
            mat[:, :3, 3] = trans
            mat[:, 3, 3] = 1.0
            return mat

        def unit_quaternion(quat):
            # Epsilons guard against a zero-norm decoder output.
            return quat / (torch.sqrt(torch.sum(quat ** 2, dim=-1, keepdim=True) + 1e-10) + 1e-10)

        # Feature extraction
        src_feats = self.feature_extraction(src_points)
        dst_feats = self.feature_extraction(dst_points)

        # Coarse registration
        src_xyz_corres_3, src_dst_weights_3, coarse_attentive_feats = self.coarse_corres(
            src_feats['xyz_3'], src_feats['desc_3'], dst_feats['xyz_3'],
            dst_feats['desc_3'], src_feats['sigmas_3'], dst_feats['sigmas_3'])

        R3, t3 = self.svd_head(src_feats['xyz_3'], src_xyz_corres_3, src_dst_weights_3)
        coarse_attentive_feats = self.mlp_feature_coarse(coarse_attentive_feats).unsqueeze(1)  # 512 -> 64

        if self.continue_frames <= 0:
            # First frame of a sequence: push an all-ones entry into the
            # feature history (and mark it in the mask) before it is read.
            self.his_coarse_attentive_feats = self.update_history(
                torch.ones_like(coarse_attentive_feats), self.his_coarse_attentive_feats)
            self.history_mask = self.update_history_mask(self.history_mask)
        # Run the temporal model over [history without its oldest entry,
        # current frame] and keep only the output for the current frame.
        coarse_attentive_feats = self.mamba_feature_coarse(
            torch.cat([self.his_coarse_attentive_feats[:, 1:], coarse_attentive_feats], dim=1))[:, -1:]

        # Encode the pose histories and keep the embedding of the last step.
        q_his_encode = self.mlp_q(self.q_history)
        t_his_encode = self.mlp_t(self.t_history)
        his_q_embed = self.mamba_q(q_his_encode)[:, -1:]
        his_t_embed = self.mamba_t(t_his_encode)[:, -1:]

        # Dynamic refinement of the coarse pose. The SVD estimate is detached,
        # so the decoders do not back-propagate into the coarse branch
        # through it.
        R3_q = matrix_to_quaternion(R3)
        R3_q_dynamic = self.q_dynamic_coarse(torch.cat(
            [coarse_attentive_feats, his_q_embed, self.q_history[:, -1:], R3_q.unsqueeze(1).detach()],
            dim=-1)).squeeze(1)
        t3_dynamic = self.t_dynamic_coarse(torch.cat(
            [coarse_attentive_feats, his_t_embed, self.t_history[:, -1:], t3.unsqueeze(1).detach()],
            dim=-1)).squeeze(1)
        R3_q_dynamic = unit_quaternion(R3_q_dynamic)
        R3_dynamic = quaternion_to_matrix(R3_q_dynamic)

        # Fine registration: Layer 2
        src_xyz_2_trans = torch.matmul(R3_dynamic, src_feats['xyz_2'].permute(0, 2, 1).contiguous()) + t3_dynamic.unsqueeze(2)
        src_xyz_2_trans = src_xyz_2_trans.permute(0, 2, 1).contiguous()
        src_xyz_corres_2, src_dst_weights_2 = self.fine_corres_2(
            src_xyz_2_trans, src_feats['desc_2'], dst_feats['xyz_2'],
            dst_feats['desc_2'], src_feats['sigmas_2'], dst_feats['sigmas_2'])

        R2_, t2_ = self.svd_head(src_xyz_2_trans, src_xyz_corres_2, src_dst_weights_2)
        T3 = to_homogeneous(R3_dynamic, t3_dynamic)
        T2_ = to_homogeneous(R2_, t2_)
        # The level-2 residual is applied after the coarse transform.
        T2 = torch.matmul(T2_, T3)
        R2 = T2[:, :3, :3]
        t2 = T2[:, :3, 3]

        # Fine registration: Layer 1
        src_xyz_1_trans = torch.matmul(R2, src_feats['xyz_1'].permute(0, 2, 1).contiguous()) + t2.unsqueeze(2)
        src_xyz_1_trans = src_xyz_1_trans.permute(0, 2, 1).contiguous()
        src_xyz_corres_1, src_dst_weights_1 = self.fine_corres_1(
            src_xyz_1_trans, src_feats['desc_1'], dst_feats['xyz_1'],
            dst_feats['desc_1'], src_feats['sigmas_1'], dst_feats['sigmas_1'])
        R1_, t1_ = self.svd_head(src_xyz_1_trans, src_xyz_corres_1, src_dst_weights_1)
        T1_ = to_homogeneous(R1_, t1_)

        T1 = torch.matmul(T1_, T2)
        R1 = T1[:, :3, :3]
        t1 = T1[:, :3, 3]

        # Dynamic refinement of the final pose, conditioned on the same
        # feature and history embeddings as the coarse stage.
        R1_q = matrix_to_quaternion(R1)
        R1_q_dynamic = self.q_dynamic_fine(torch.cat(
            [coarse_attentive_feats, his_q_embed, self.q_history[:, -1:], R1_q.unsqueeze(1).detach()],
            dim=-1)).squeeze(1)
        t1_dynamic = self.t_dynamic_fine(torch.cat(
            [coarse_attentive_feats, his_t_embed, self.t_history[:, -1:], t1.unsqueeze(1).detach()],
            dim=-1)).squeeze(1)
        R1_q_dynamic = unit_quaternion(R1_q_dynamic)
        R1_dynamic = quaternion_to_matrix(R1_q_dynamic)

        # Push this frame's refined pose and feature embedding into the
        # histories (update_history stores detached copies).
        self.q_history = self.update_history(R1_q_dynamic.unsqueeze(1), self.q_history)
        self.t_history = self.update_history(t1_dynamic.unsqueeze(1), self.t_history)
        self.his_coarse_attentive_feats = self.update_history(coarse_attentive_feats, self.his_coarse_attentive_feats)
        self.history_mask = self.update_history_mask(self.history_mask)
        self.continue_frames += 1

        ret_dict = {}
        ret_dict['rotation'] = [R3, R2, R1, R3_dynamic, R1_dynamic]
        ret_dict['translation'] = [t3, t2, t1, t3_dynamic, t1_dynamic]
        ret_dict['src_feats'] = src_feats
        ret_dict['dst_feats'] = dst_feats

        return ret_dict